{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "#| eval: false\n",
    "! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp data.block"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "from __future__ import annotations\n",
    "from fastai.torch_basics import *\n",
    "from fastai.data.core import *\n",
    "from fastai.data.load import *\n",
    "from fastai.data.external import *\n",
    "from fastai.data.transforms import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data block\n",
    "\n",
    "> High level API to quickly get your data in a `DataLoaders`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TransformBlock -"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> 📘 **Note**: Several domain-specific blocks such as `ImageBlock`, `BBoxBlock`, `PointBlock`, and `CategoryBlock` are implemented on top of `TransformBlock`. These blocks are designed to handle common tasks in computer vision, classification, and regression. See the [Vision Blocks](https://docs.fast.ai/data.block.html#Vision-blocks) section for more details.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class TransformBlock():\n",
    "    \"A basic wrapper that links defaults transforms for the data block API\"\n",
    "    def __init__(self, \n",
    "        type_tfms:list=None, # One or more `Transform`s\n",
    "        item_tfms:list=None, # `ItemTransform`s, applied on an item\n",
    "        batch_tfms:list=None, # `Transform`s or `RandTransform`s, applied by batch\n",
    "        dl_type:TfmdDL=None, # Task specific `TfmdDL`, defaults to `TfmdDL`\n",
    "        dls_kwargs:dict=None, # Additional arguments to be passed to `DataLoaders`\n",
    "    ):\n",
    "        self.type_tfms  =            L(type_tfms)\n",
    "        self.item_tfms  = ToTensor + L(item_tfms)\n",
    "        self.batch_tfms =            L(batch_tfms)\n",
    "        self.dl_type,self.dls_kwargs = dl_type,({} if dls_kwargs is None else dls_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def CategoryBlock(\n",
    "    vocab:MutableSequence|pd.Series=None, # List of unique class names\n",
    "    sort:bool=True, # Sort the classes alphabetically\n",
    "    add_na:bool=False, # Add `#na#` to `vocab`\n",
    "):\n",
    "    \"`TransformBlock` for single-label categorical targets\"\n",
    "    return TransformBlock(type_tfms=Categorize(vocab=vocab, sort=sort, add_na=add_na))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def MultiCategoryBlock(\n",
    "    encoded:bool=False, # Whether the data comes in one-hot encoded\n",
    "    vocab:MutableSequence|pd.Series=None, # List of unique class names \n",
    "    add_na:bool=False, # Add `#na#` to `vocab`\n",
    "):\n",
    "    \"`TransformBlock` for multi-label categorical targets\"\n",
    "    tfm = EncodedMultiCategorize(vocab=vocab) if encoded else [MultiCategorize(vocab=vocab, add_na=add_na), OneHotEncode]\n",
    "    return TransformBlock(type_tfms=tfm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def RegressionBlock(\n",
    "    n_out:int=None, # Number of output values\n",
    "):\n",
    "    \"`TransformBlock` for float targets\"\n",
    "    return TransformBlock(type_tfms=RegressionSetup(c=n_out))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## General API"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "from inspect import isfunction,ismethod"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def _merge_grouper(o):\n",
    "    if isinstance(o, LambdaType): return id(o)\n",
    "    elif isinstance(o, type): return o\n",
    "    elif (isfunction(o) or ismethod(o)): return o.__qualname__\n",
    "    return o.__class__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def _merge_tfms(*tfms):\n",
    "    \"Group the `tfms` in a single list, removing duplicates (from the same class) and instantiating\"\n",
    "    g = groupby(concat(*tfms), _merge_grouper)\n",
    "    return L(v[-1] for k,v in g.items()).map(instantiate)\n",
    "\n",
    "def _zip(x): return L(x).zip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#For example, so not exported\n",
    "from fastai.vision.core import *\n",
    "from fastai.vision.data import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "tfms = _merge_tfms([Categorize, MultiCategorize, Categorize(['dog', 'cat'])], Categorize(['a', 'b']))\n",
    "#If there are several instantiated versions, the last one is kept.\n",
    "test_eq(len(tfms), 2)\n",
    "test_eq(tfms[1].__class__, MultiCategorize)\n",
    "test_eq(tfms[0].__class__, Categorize)\n",
    "test_eq(tfms[0].vocab, ['a', 'b'])\n",
    "\n",
    "tfms = _merge_tfms([PILImage.create, PILImage.show])\n",
    "#Check methods are properly separated\n",
    "test_eq(len(tfms), 2)\n",
    "tfms = _merge_tfms([show_image, set_trace])\n",
    "#Check functions are properly separated\n",
    "test_eq(len(tfms), 2)\n",
    "\n",
    "_f = lambda x: 0\n",
    "test_eq(len(_merge_tfms([_f,lambda x: 1])), 2)\n",
    "test_eq(len(_merge_tfms([_f,_f])), 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@docs\n",
    "@funcs_kwargs\n",
    "class DataBlock():\n",
    "    \"Generic container to quickly build `Datasets` and `DataLoaders`.\"\n",
    "    get_x=get_items=splitter=get_y = None\n",
    "    blocks,dl_type = (TransformBlock,TransformBlock),TfmdDL\n",
    "    _methods = 'get_items splitter get_y get_x'.split()\n",
    "    _msg = \"If you wanted to compose several transforms in your getter don't forget to wrap them in a `Pipeline`.\"\n",
    "    def __init__(self, \n",
    "        blocks:list=None, # One or more `TransformBlock`s\n",
    "        dl_type:TfmdDL=None, # Task specific `TfmdDL`, defaults to `block`'s dl_type or`TfmdDL`\n",
    "        getters:list=None, # Getter functions applied to results of `get_items`\n",
    "        n_inp:int=None, # Number of inputs\n",
    "        item_tfms:list=None, # `ItemTransform`s, applied on an item \n",
    "        batch_tfms:list=None, # `Transform`s or `RandTransform`s, applied by batch\n",
    "        **kwargs, \n",
    "    ):\n",
    "        blocks = L(self.blocks if blocks is None else blocks)\n",
    "        blocks = L(b() if callable(b) else b for b in blocks)\n",
    "        self.type_tfms = blocks.attrgot('type_tfms', L())\n",
    "        self.default_item_tfms  = _merge_tfms(*blocks.attrgot('item_tfms',  L()))\n",
    "        self.default_batch_tfms = _merge_tfms(*blocks.attrgot('batch_tfms', L()))\n",
    "        for b in blocks:\n",
    "            if getattr(b, 'dl_type', None) is not None: self.dl_type = b.dl_type\n",
    "        if dl_type is not None: self.dl_type = dl_type\n",
    "        self.dataloaders = delegates(self.dl_type.__init__)(self.dataloaders)\n",
    "        self.dls_kwargs = merge(*blocks.attrgot('dls_kwargs', {}))\n",
    "\n",
    "        self.n_inp = ifnone(n_inp, max(1, len(blocks)-1))\n",
    "        self.getters = ifnone(getters, [noop]*len(self.type_tfms))\n",
    "        if self.get_x:\n",
    "            if len(L(self.get_x)) != self.n_inp:\n",
    "                raise ValueError(f'get_x contains {len(L(self.get_x))} functions, but must contain {self.n_inp} (one for each input)\\n{self._msg}')\n",
    "            self.getters[:self.n_inp] = L(self.get_x)\n",
    "        if self.get_y:\n",
    "            n_targs = len(self.getters) - self.n_inp\n",
    "            if len(L(self.get_y)) != n_targs:\n",
    "                raise ValueError(f'get_y contains {len(L(self.get_y))} functions, but must contain {n_targs} (one for each target)\\n{self._msg}')\n",
    "            self.getters[self.n_inp:] = L(self.get_y)\n",
    "\n",
    "        if kwargs: raise TypeError(f'invalid keyword arguments: {\", \".join(kwargs.keys())}')\n",
    "        self.new(item_tfms, batch_tfms)\n",
    "\n",
    "    def _combine_type_tfms(self): return L([self.getters, self.type_tfms]).map_zip(\n",
    "        lambda g,tt: (g.fs if isinstance(g, Pipeline) else L(g)) + tt)\n",
    "\n",
    "    def new(self, \n",
    "        item_tfms:list=None, # `ItemTransform`s, applied on an item\n",
    "        batch_tfms:list=None, # `Transform`s or `RandTransform`s, applied by batch \n",
    "    ):\n",
    "        self.item_tfms  = _merge_tfms(self.default_item_tfms,  item_tfms)\n",
    "        self.batch_tfms = _merge_tfms(self.default_batch_tfms, batch_tfms)\n",
    "        return self\n",
    "\n",
    "    @classmethod\n",
    "    def from_columns(cls, \n",
    "        blocks:list =None, # One or more `TransformBlock`s\n",
    "        getters:list =None, # Getter functions applied to results of `get_items`\n",
    "        get_items:Callable=None, # A function to get items\n",
    "        **kwargs,\n",
    "    ):\n",
    "        if getters is None: getters = L(ItemGetter(i) for i in range(2 if blocks is None else len(L(blocks))))\n",
    "        get_items = _zip if get_items is None else compose(get_items, _zip)\n",
    "        return cls(blocks=blocks, getters=getters, get_items=get_items, **kwargs)\n",
    "\n",
    "    def datasets(self, \n",
    "        source, # The data source\n",
    "        verbose:bool=False, # Show verbose messages\n",
    "    ) -> Datasets:\n",
    "        self.source = source                     ; pv(f\"Collecting items from {source}\", verbose)\n",
    "        items = (self.get_items or noop)(source) ; pv(f\"Found {len(items)} items\", verbose)\n",
    "        splits = (self.splitter or RandomSplitter())(items)\n",
    "        pv(f\"{len(splits)} datasets of sizes {','.join([str(len(s)) for s in splits])}\", verbose)\n",
    "        return Datasets(items, tfms=self._combine_type_tfms(), splits=splits, dl_type=self.dl_type, n_inp=self.n_inp, verbose=verbose)\n",
    "\n",
    "    def dataloaders(self, \n",
    "        source, # The data source\n",
    "        path:str='.', # Data source and default `Learner` path \n",
    "        verbose:bool=False, # Show verbose messages\n",
    "        **kwargs\n",
    "    ) -> DataLoaders:\n",
    "        dsets = self.datasets(source, verbose=verbose)\n",
    "        kwargs = {**self.dls_kwargs, **kwargs, 'verbose': verbose}\n",
    "        return dsets.dataloaders(path=path, after_item=self.item_tfms, after_batch=self.batch_tfms, **kwargs)\n",
    "\n",
    "    _docs = dict(new=\"Create a new `DataBlock` with other `item_tfms` and `batch_tfms`\",\n",
    "                 datasets=\"Create a `Datasets` object from `source`\",\n",
    "                 dataloaders=\"Create a `DataLoaders` object from `source`\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To build a `DataBlock` you need to give the library four things: the types of your input/labels, and at least two functions: `get_items` and `splitter`. You may also need to include `get_x` and `get_y` or a more generic list of `getters` that are applied to the results of `get_items`.\n",
    "\n",
    "splitter is a callable which, when called with `items`, returns a tuple of iterables representing the indices of the training and validation data.\n",
    "\n",
    "Once those are provided, you automatically get a `Datasets` or a `DataLoaders`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastai/blob/main/fastai/data/block.py#L141){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### DataBlock.datasets\n",
       "\n",
       ">      DataBlock.datasets (source, verbose:bool=False)\n",
       "\n",
       "*Create a `Datasets` object from `source`*\n",
       "\n",
       "|    | **Type** | **Default** | **Details** |\n",
       "| -- | -------- | ----------- | ----------- |\n",
       "| source |  |  | The data source |\n",
       "| verbose | bool | False | Show verbose messages |\n",
       "| **Returns** | **Datasets** |  |  |"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastai/blob/main/fastai/data/block.py#L141){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### DataBlock.datasets\n",
       "\n",
       ">      DataBlock.datasets (source, verbose:bool=False)\n",
       "\n",
       "*Create a `Datasets` object from `source`*\n",
       "\n",
       "|    | **Type** | **Default** | **Details** |\n",
       "| -- | -------- | ----------- | ----------- |\n",
       "| source |  |  | The data source |\n",
       "| verbose | bool | False | Show verbose messages |\n",
       "| **Returns** | **Datasets** |  |  |"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(DataBlock.datasets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastai/blob/main/fastai/data/block.py#L151){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### DataBlock.dataloaders\n",
       "\n",
       ">      DataBlock.dataloaders (source, path:str='.', verbose:bool=False,\n",
       ">                             bs:int=64, shuffle:bool=False,\n",
       ">                             num_workers:int=None, do_setup:bool=True,\n",
       ">                             pin_memory=False, timeout=0, batch_size=None,\n",
       ">                             drop_last=False, indexed=None, n=None,\n",
       ">                             device=None, persistent_workers=False,\n",
       ">                             pin_memory_device='', wif=None, before_iter=None,\n",
       ">                             after_item=None, before_batch=None,\n",
       ">                             after_batch=None, after_iter=None,\n",
       ">                             create_batches=None, create_item=None,\n",
       ">                             create_batch=None, retain=None, get_idxs=None,\n",
       ">                             sample=None, shuffle_fn=None, do_batch=None)\n",
       "\n",
       "*Create a `DataLoaders` object from `source`*\n",
       "\n",
       "|    | **Type** | **Default** | **Details** |\n",
       "| -- | -------- | ----------- | ----------- |\n",
       "| source |  |  | The data source |\n",
       "| path | str | . | Data source and default `Learner` path |\n",
       "| verbose | bool | False | Show verbose messages |\n",
       "| bs | int | 64 | Size of batch |\n",
       "| shuffle | bool | False | Whether to shuffle data |\n",
       "| num_workers | int | None | Number of CPU cores to use in parallel (default: All available up to 16) |\n",
       "| do_setup | bool | True | Whether to run `setup()` for batch transform(s) |\n",
       "| pin_memory | bool | False |  |\n",
       "| timeout | int | 0 |  |\n",
       "| batch_size | NoneType | None |  |\n",
       "| drop_last | bool | False |  |\n",
       "| indexed | NoneType | None |  |\n",
       "| n | NoneType | None |  |\n",
       "| device | NoneType | None |  |\n",
       "| persistent_workers | bool | False |  |\n",
       "| pin_memory_device | str |  |  |\n",
       "| wif | NoneType | None |  |\n",
       "| before_iter | NoneType | None |  |\n",
       "| after_item | NoneType | None |  |\n",
       "| before_batch | NoneType | None |  |\n",
       "| after_batch | NoneType | None |  |\n",
       "| after_iter | NoneType | None |  |\n",
       "| create_batches | NoneType | None |  |\n",
       "| create_item | NoneType | None |  |\n",
       "| create_batch | NoneType | None |  |\n",
       "| retain | NoneType | None |  |\n",
       "| get_idxs | NoneType | None |  |\n",
       "| sample | NoneType | None |  |\n",
       "| shuffle_fn | NoneType | None |  |\n",
       "| do_batch | NoneType | None |  |\n",
       "| **Returns** | **DataLoaders** |  |  |"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastai/blob/main/fastai/data/block.py#L151){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### DataBlock.dataloaders\n",
       "\n",
       ">      DataBlock.dataloaders (source, path:str='.', verbose:bool=False,\n",
       ">                             bs:int=64, shuffle:bool=False,\n",
       ">                             num_workers:int=None, do_setup:bool=True,\n",
       ">                             pin_memory=False, timeout=0, batch_size=None,\n",
       ">                             drop_last=False, indexed=None, n=None,\n",
       ">                             device=None, persistent_workers=False,\n",
       ">                             pin_memory_device='', wif=None, before_iter=None,\n",
       ">                             after_item=None, before_batch=None,\n",
       ">                             after_batch=None, after_iter=None,\n",
       ">                             create_batches=None, create_item=None,\n",
       ">                             create_batch=None, retain=None, get_idxs=None,\n",
       ">                             sample=None, shuffle_fn=None, do_batch=None)\n",
       "\n",
       "*Create a `DataLoaders` object from `source`*\n",
       "\n",
       "|    | **Type** | **Default** | **Details** |\n",
       "| -- | -------- | ----------- | ----------- |\n",
       "| source |  |  | The data source |\n",
       "| path | str | . | Data source and default `Learner` path |\n",
       "| verbose | bool | False | Show verbose messages |\n",
       "| bs | int | 64 | Size of batch |\n",
       "| shuffle | bool | False | Whether to shuffle data |\n",
       "| num_workers | int | None | Number of CPU cores to use in parallel (default: All available up to 16) |\n",
       "| do_setup | bool | True | Whether to run `setup()` for batch transform(s) |\n",
       "| pin_memory | bool | False |  |\n",
       "| timeout | int | 0 |  |\n",
       "| batch_size | NoneType | None |  |\n",
       "| drop_last | bool | False |  |\n",
       "| indexed | NoneType | None |  |\n",
       "| n | NoneType | None |  |\n",
       "| device | NoneType | None |  |\n",
       "| persistent_workers | bool | False |  |\n",
       "| pin_memory_device | str |  |  |\n",
       "| wif | NoneType | None |  |\n",
       "| before_iter | NoneType | None |  |\n",
       "| after_item | NoneType | None |  |\n",
       "| before_batch | NoneType | None |  |\n",
       "| after_batch | NoneType | None |  |\n",
       "| after_iter | NoneType | None |  |\n",
       "| create_batches | NoneType | None |  |\n",
       "| create_item | NoneType | None |  |\n",
       "| create_batch | NoneType | None |  |\n",
       "| retain | NoneType | None |  |\n",
       "| get_idxs | NoneType | None |  |\n",
       "| sample | NoneType | None |  |\n",
       "| shuffle_fn | NoneType | None |  |\n",
       "| do_batch | NoneType | None |  |\n",
       "| **Returns** | **DataLoaders** |  |  |"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#| echo: false\n",
    "dblock = DataBlock()\n",
    "show_doc(dblock.dataloaders, name=\"DataBlock.dataloaders\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can create a `DataBlock` by passing functions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mnist = DataBlock(blocks = (ImageBlock(cls=PILImageBW),CategoryBlock),\n",
    "                  get_items = get_image_files,\n",
    "                  splitter = GrandparentSplitter(),\n",
    "                  get_y = parent_label)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each type comes with default transforms that will be applied:\n",
    "\n",
    "- at the base level to create items in a tuple (usually input,target) from the base elements (like filenames)\n",
    "- at the item level of the datasets\n",
    "- at the batch level\n",
    "\n",
    "They are called respectively type transforms, item transforms, batch transforms. In the case of MNIST, the type transforms are the method to create a `PILImageBW` (for the input) and the `Categorize` transform (for the target), the item transform is `ToTensor` and the batch transforms are `Cuda` and `IntToFloatTensor`. You can add any other transforms by passing them in `DataBlock.datasets` or `DataBlock.dataloaders`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(mnist.type_tfms[0], [PILImageBW.create])\n",
    "test_eq(mnist.type_tfms[1].map(type), [Categorize])\n",
    "test_eq(mnist.default_item_tfms.map(type), [ToTensor])\n",
    "test_eq(mnist.default_batch_tfms.map(type), [IntToFloatTensor])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAHsAAACLCAYAAABBVeZmAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAEbklEQVR4nO2dPUscURSGz1ULMSuKWAiBrYIgNloIEkRQQawl9aawsRARRBC0EOwF3c5GsggKQvAHqKDZRrFVtFHsTLRQEvBznRTBJecmrqvO1933fUCYlxl3z/p458zHzozxPE8IBmVRF0DCg7KBoGwgKBsIygaCsoGgbCCgZBtjflk/OWNMOuq6wqIi6gLCxPO8xOO0MeadiHwXkZXoKgoXqJFt8UlEfojIt6gLCQtk2Z9FJOMBHS82QJ81jzEmKSLHIvLB87zjqOsJC9SRnRKRLJJoEWzZX6IuImzgZBtjPorIewHaCn8ETrb82TD76nnez6gLCRvIDTRUEEc2LJQNBGUDQdlAUDYQz5314qa6e5inZnBkA0HZQFA2EJQNBGUDQdlAUDYQlA0EZQNB2UBQNhCUDQRlA0HZQFA2EJQNBGUDQdlAUDYQlA0EZQNB2UBQNhCUDQRlA0HZQDhz07v9/X2VM5mMysvLyypXVOiPNjw8nJ8eGRlR81paWlTu7e19ZZX/x3697u5uX1+/WDiygaBsICgbiOduoBObS3YHBgZUXlhYiKiSl9PU1KTy3t5ekG/HS3YJZUNB2UA407M3NzdV7uvrU/nm5ibMcl5EQ0ODyoeHhypXV1f7+Xbs2YSyoaBsIJzp2TaXl5cqr6+vq3xyclL0a9n7vc3NzSpfXFyoPD09rbL9NzRGt822tjaVt7e3i67tFbBnE8qGgrKBcOZ8tk1NTY3K/f39gb2X3aNt7B5tn79eXFz0vabXwJENBGUDQdlAONuzg+T+/l7lpaWlgsvX1taqnEqlVK6vr/elrrfCkQ0EZQNB2UCwZ8u/x7ZHR0dVPjg4KPj7Kyv6SY89PT3+FOYzHNlAUDYQkKtxe9fK/lpyOp0u+PszMzMqR3U5z0vhyAaCsoGgbCCc/VrSW8hmsyp3dnYWXD6RSKh8enqqclVVlT+F+QO/lkQoGwrKBgJmP/vo6Cg/3dXVVXBZu0dvbW2pHLMeXTQc2UBQNhCUDUTJ9uyrqyuVBwcH89O5XE7NKyvT//N2j7ZvneUqHNlAUDYQlA1Eyfbs8fFxldfW1p5cdmhoSOVS6dE2HNlAUDYQlA1EyZzP3t3dVbm9vV3lh4eH/HRHR4eat7GxobJ9+2rH4PlsQtlQUDYQzjanu7s7lcfGxlT+u0eLiCSTyfz07Oysmud4jy4ajmwgKBsIygbC2WZlP67JvkV1eXm5ynNzc/np1tbWwOqKMxzZQFA2EJQNhDPHxu1bStuPT7q+vlbZvgXlxMREMIXFDx4bJ5QNBWUDEduefXt7q7J9DfXOzo7Kjl1DHSTs2YSyoYjt4dL5+XmV7dW2zerqqspAq+2i4cgGgrKBoGwgYrPrdX5+rnJjY6PK9lPzJicnVZ6amlLZvgwXCO56EcqGgrKBiE3PPjs7U9nu2XV1dSrbT8atrKwMpjD3YM8mlA0FZQMRm55NfIM9m1A2FJQNxHPns59c/xP34MgGgrKBoGwgKBsIygaCsoH4Dd9p/n3ENZhPAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 144x144 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "dsets = mnist.datasets(untar_data(URLs.MNIST_TINY))\n",
    "test_eq(dsets.vocab, ['3', '7'])\n",
    "x,y = dsets.train[0]\n",
    "test_eq(x.size,(28,28))\n",
    "show_at(dsets.train, 0, cmap='Greys', figsize=(2,2));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_fail(lambda: DataBlock(wrong_kwarg=42, wrong_kwarg2='foo'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can pass any number of blocks to `DataBlock`, we can then define what are the input and target blocks by changing `n_inp`. For example, defining `n_inp=2` will consider the first two blocks passed as inputs and the others as targets. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mnist = DataBlock((ImageBlock, ImageBlock, CategoryBlock), get_items=get_image_files, splitter=GrandparentSplitter(),\n",
    "                   get_y=parent_label)\n",
    "dsets = mnist.datasets(untar_data(URLs.MNIST_TINY))\n",
    "test_eq(mnist.n_inp, 2)\n",
    "test_eq(len(dsets.train[0]), 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_fail(lambda: DataBlock((ImageBlock, ImageBlock, CategoryBlock), get_items=get_image_files, splitter=GrandparentSplitter(),\n",
    "                  get_y=[parent_label, noop],\n",
    "                  n_inp=2), msg='get_y contains 2 functions, but must contain 1 (one for each output)')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mnist = DataBlock((ImageBlock, ImageBlock, CategoryBlock), get_items=get_image_files, splitter=GrandparentSplitter(),\n",
    "                  n_inp=1,\n",
    "                  get_y=[noop, Pipeline([noop, parent_label])])\n",
    "dsets = mnist.datasets(untar_data(URLs.MNIST_TINY))\n",
    "test_eq(len(dsets.train[0]), 3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Debugging"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def _short_repr(x):\n",
    "    if isinstance(x, tuple): return f'({\", \".join([_short_repr(y) for y in x])})'\n",
    "    if isinstance(x, list): return f'[{\", \".join([_short_repr(y) for y in x])}]'\n",
    "    if not isinstance(x, Tensor): return str(x)\n",
    "    if x.numel() <= 20 and x.ndim <=1: return str(x)\n",
    "    return f'{x.__class__.__name__} of size {\"x\".join([str(d) for d in x.shape])}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "test_eq(_short_repr(TensorImage(torch.randn(40,56))), 'TensorImage of size 40x56')\n",
    "test_eq(_short_repr(TensorCategory([1,2,3])), 'TensorCategory([1, 2, 3])')\n",
    "test_eq(_short_repr((TensorImage(torch.randn(40,56)), TensorImage(torch.randn(32,20)))),\n",
    "        '(TensorImage of size 40x56, TensorImage of size 32x20)')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def _apply_pipeline(p, x):\n",
    "    print(f\"  {p}\\n    starting from\\n      {_short_repr(x)}\")\n",
    "    for f in p.fs:\n",
    "        name = f.name\n",
    "        try:\n",
    "            x = f(x)\n",
    "            if name != \"noop\": print(f\"    applying {name} gives\\n      {_short_repr(x)}\")\n",
    "        except Exception as e:\n",
    "            print(f\"    applying {name} failed.\")\n",
    "            raise e\n",
    "    return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "from fastai.data.load import _collate_types\n",
    "\n",
    "def _find_fail_collate(s):\n",
    "    s = L(*s)\n",
    "    for x in s[0]:\n",
    "        if not isinstance(x, _collate_types): return f\"{type(x).__name__} is not collatable\"\n",
    "    for i in range_of(s[0]):\n",
    "        try: _ = default_collate(s.itemgot(i))\n",
    "        except:\n",
    "            shapes = [getattr(o[i], 'shape', None) for o in s]\n",
    "            return f\"Could not collate the {i}-th members of your tuples because got the following shapes\\n{','.join([str(s) for s in shapes])}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@patch\n",
    "def summary(self:DataBlock,\n",
    "    source, # The data source  \n",
    "    bs:int=4, # The batch size\n",
    "    show_batch:bool=False, # Call `show_batch` after the summary\n",
    "    **kwargs, # Additional keyword arguments to `show_batch`\n",
    "):\n",
    "    \"Steps through the transform pipeline for one batch, and optionally calls `show_batch(**kwargs)` on the transient `Dataloaders`.\"\n",
    "    print(f\"Setting-up type transforms pipelines\")\n",
    "    dsets = self.datasets(source, verbose=True)\n",
    "    print(\"\\nBuilding one sample\")\n",
    "    for tl in dsets.train.tls:\n",
    "        _apply_pipeline(tl.tfms, get_first(dsets.train.items))\n",
    "    print(f\"\\nFinal sample: {dsets.train[0]}\\n\\n\")\n",
    "\n",
    "    dls = self.dataloaders(source, bs=bs, verbose=True)\n",
    "    print(\"\\nBuilding one batch\")\n",
    "    if len([f for f in dls.train.after_item.fs if f.name != 'noop'])!=0:\n",
    "        print(\"Applying item_tfms to the first sample:\")\n",
    "        s = [_apply_pipeline(dls.train.after_item, dsets.train[0])]\n",
    "        print(f\"\\nAdding the next {bs-1} samples\")\n",
    "        s += [dls.train.after_item(dsets.train[i]) for i in range(1, bs)]\n",
    "    else:\n",
    "        print(\"No item_tfms to apply\")\n",
    "        s = [dls.train.after_item(dsets.train[i]) for i in range(bs)]\n",
    "\n",
    "    if len([f for f in dls.train.before_batch.fs if f.name != 'noop'])!=0:\n",
    "        print(\"\\nApplying before_batch to the list of samples\")\n",
    "        s = _apply_pipeline(dls.train.before_batch, s)\n",
    "    else: print(\"\\nNo before_batch transform to apply\")\n",
    "\n",
    "    print(\"\\nCollating items in a batch\")\n",
    "    try:\n",
    "        b = dls.train.create_batch(s)\n",
    "        b = retain_types(b, s[0] if is_listy(s) else s)\n",
    "    except Exception as e:\n",
    "        print(\"Error! It's not possible to collate your items in a batch\")\n",
    "        why = _find_fail_collate(s)\n",
    "        print(\"Make sure all parts of your samples are tensors of the same size\" if why is None else why)\n",
    "        raise e\n",
    "\n",
    "    if len([f for f in dls.train.after_batch.fs if f.name != 'noop'])!=0:\n",
    "        print(\"\\nApplying batch_tfms to the batch built\")\n",
    "        b = to_device(b, dls.device)\n",
    "        b = _apply_pipeline(dls.train.after_batch, b)\n",
    "    else: print(\"\\nNo batch_tfms to apply\")\n",
    "\n",
    "    if show_batch: dls.show_batch(**kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastai/blob/main/fastai/data/block.py#L201){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### DataBlock.summary\n",
       "\n",
       ">      DataBlock.summary (source, bs:int=4, show_batch:bool=False, **kwargs)\n",
       "\n",
       "*Steps through the transform pipeline for one batch, and optionally calls `show_batch(**kwargs)` on the transient `Dataloaders`.*\n",
       "\n",
       "|    | **Type** | **Default** | **Details** |\n",
       "| -- | -------- | ----------- | ----------- |\n",
       "| source |  |  | The data source |\n",
       "| bs | int | 4 | The batch size |\n",
       "| show_batch | bool | False | Call `show_batch` after the summary |\n",
       "| kwargs | VAR_KEYWORD |  |  |"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastai/blob/main/fastai/data/block.py#L201){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### DataBlock.summary\n",
       "\n",
       ">      DataBlock.summary (source, bs:int=4, show_batch:bool=False, **kwargs)\n",
       "\n",
       "*Steps through the transform pipeline for one batch, and optionally calls `show_batch(**kwargs)` on the transient `Dataloaders`.*\n",
       "\n",
       "|    | **Type** | **Default** | **Details** |\n",
       "| -- | -------- | ----------- | ----------- |\n",
       "| source |  |  | The data source |\n",
       "| bs | int | 4 | The batch size |\n",
       "| show_batch | bool | False | Call `show_batch` after the summary |\n",
       "| kwargs | VAR_KEYWORD |  |  |"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(DataBlock.summary)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Besides stepping through the transformation, `summary()`  provides a shortcut `dls.show_batch(...)`, to see the data.  E.g.\n",
    "\n",
    "```\n",
    "pets.summary(path/\"images\", bs=8, show_batch=True, unique=True,...)\n",
    "```\n",
    "\n",
    "is a shortcut to:\n",
    "```\n",
    "pets.summary(path/\"images\", bs=8)\n",
    "dls = pets.dataloaders(path/\"images\", bs=8)\n",
    "dls.show_batch(unique=True,...)  # See different tfms effect on the same image.\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev import nbdev_export\n",
    "nbdev_export()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
